﻿using Apewer.Network;
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;

namespace Apewer.Web
{

    /// <summary>Socket connection that reads HTTP requests and writes responses for a <see cref="MiniServer"/>.</summary>
    public sealed class MiniConnection
    {

        // Timeouts, in milliseconds.
        const int RequestHeadTimeout = 15 * 1000;
        const int RequestBodyTimeout = 3600 * 1000;
        const int ResponseHeadTimeout = 15 * 1000;
        const int ResponseBodyTimeout = 3600 * 1000;
        const int StreamReadTimeout = 3 * 1000;
        const int StreamWriteTimeout = 3 * 1000;

        // Buffering: size of one read, and the largest request head accepted (bytes).
        const int BufferSize = 8192;
        const int HeadMax = 32768;

        // Completion callback for asynchronous head reads; the connection is passed as AsyncState.
        static readonly AsyncCallback HeadReadCallback = ar => ((MiniConnection)ar.AsyncState).ProcessHead(ar);

        // State kept for the lifetime of the connection.
        Timer _timer;
        int _reuses = 0;

        // Per-request state, reset by BeginRead before each request.
        MiniContext _context;
        ArrayBuilder<byte> _head = null;
        byte[] _buffer;
        MiniReader _instream = null;
        MiniWriter _outstream = null;
        bool _headsent = false;
        bool _bodysent = false;
        bool _forceClose = false;
        bool _chunked = false;
        int _headread = 0;

        internal MiniConnection(MiniServer server, Socket socket)
        {
            _server = server;
            _socket = socket;

            // The NetworkStream does not own the socket; CloseSocket closes the socket separately.
            _netstream = new NetworkStream(socket, false);
            _netstream.ReadTimeout = StreamReadTimeout;
            _netstream.WriteTimeout = StreamWriteTimeout;
            _local = socket.LocalEndPoint as IPEndPoint;
            _remote = socket.RemoteEndPoint as IPEndPoint;
            _stream = _netstream;

            // With an SSL certificate configured, all traffic goes through a server-side SslStream;
            // the TLS handshake runs here, synchronously.
            var certificate = server.SslCertificate;
            if (certificate != null)
            {
                _sslstream = CreateSslStream(_netstream, certificate, server.SslProtocols);
                _stream = _sslstream;
            }
        }

        // Arms the watchdog timer so that CloseSocket runs after `duration` milliseconds.
        // A non-positive duration (the default) disarms and disposes the timer.
        void Timeout(int duration = System.Threading.Timeout.Infinite)
        {
            const int Infinite = System.Threading.Timeout.Infinite;
            if (duration > 0)
            {
                // Start a new timer with the requested due time; creating it with an infinite
                // due time would leave it dormant until a later call.
                if (_timer == null) _timer = new Timer(state => CloseSocket(), null, duration, Infinite);
                else _timer.Change(duration, Infinite);
                return;
            }

            if (_timer != null)
            {
                _timer.Change(Infinite, Infinite);
                _timer.Dispose();
                _timer = null;
            }
        }

        #region socket

        MiniServer _server = null;
        Socket _socket = null;
        Stream _stream;
        NetworkStream _netstream;
        SslStream _sslstream;
        IPEndPoint _local;
        IPEndPoint _remote;

        /// <summary>The server that accepted this connection.</summary>
        public MiniServer Server { get => _server; }

        /// <summary>Local network end point.</summary>
        public IPEndPoint LocalEndPoint { get => _local; }

        /// <summary>Remote network end point.</summary>
        public IPEndPoint RemoteEndPoint { get => _remote; }

        // Stream used for all request/response I/O: the SslStream when TLS is enabled, otherwise the NetworkStream.
        internal Stream Stream { get => _stream; }

        // Disarms the timer, then closes the TLS stream, the network stream and the socket.
        // Best effort: every step ignores its own errors, and the socket reference is cleared.
        void CloseSocket()
        {
            Timeout();

            if (_sslstream != null)
            {
                try { _sslstream.Close(); } catch { }
                try { _sslstream.Dispose(); } catch { }
            }

            if (_netstream != null)
            {
                try { _netstream.Close(); } catch { }
                try { _netstream.Dispose(); } catch { }
            }

            if (_socket != null)
            {
                try { _socket.LingerState = new LingerOption(false, 0); } catch { }
                try { _socket.Shutdown(SocketShutdown.Both); } catch { }
                try { _socket.Disconnect(false); } catch { }
                try { _socket.Close(1); } catch { }
                _socket = null;
            }
        }

        /// <summary>
        /// Finishes the current response (head and body), then closes the Socket connection, or,
        /// when the request asked for keep-alive and the response does not carry "Connection: close",
        /// starts reading the next request on the same connection.
        /// </summary>
        /// <param name="force">Close the socket regardless of keep-alive.</param>
        public void Close(bool force = false)
        {
            if (_socket == null) return;
            SendHead();
            SendBody();

            // When the body had not been sent yet, SendBody re-enters Close, which may already
            // have closed the socket; there is nothing left to decide in that case.
            if (_socket == null) return;

            if (force
                || !_context.Request.KeepAlive
                || TextUtility.Lower(_context.Response.Headers.GetValue("connection", true)) == "close")
            {
                CloseSocket();
                return;
            }

            _reuses++;
            BeginRead();
        }

        #endregion

        #region request

        // Resets the per-request state and starts reading the next request head, either
        // synchronously (server.SynchronousIO) or through Stream.BeginRead.
        // Any exception escaping this path closes the connection.
        internal void BeginRead()
        {
            // reset
            _context = new MiniContext(this);
            _head = new ArrayBuilder<byte>(BufferSize);
            if (_buffer == null) _buffer = new byte[BufferSize];
            _instream = null;
            _outstream = null;
            _headsent = false;
            _bodysent = false;
            _forceClose = false;
            _chunked = _server.Chunked;
            _headread = 0;

            try
            {
                Timeout(RequestHeadTimeout);
                if (_server.SynchronousIO)
                {
                    // SyncRead returns false when it has already answered (400) and closed.
                    if (SyncRead()) ProcessHead();
                }
                else
                {
                    _stream.BeginRead(_buffer, 0, BufferSize, HeadReadCallback, this);
                }
            }
            catch
            {
                Timeout();
                CloseSocket();
            }
        }

        // Reads synchronously until the request head is complete.
        // Returns true when the head was parsed. Returns false when the head grew beyond HeadMax;
        // a 400 has then been sent and the connection closed.
        // Throws IOException when the peer closes the stream before the head is complete.
        bool SyncRead()
        {
            while (true)
            {
                var nread = _stream.Read(_buffer, 0, BufferSize);
                if (nread < 1) throw new IOException("未接收到完整的请求头。");
                _head.Add(_buffer, 0, nread);

                if (ReadHead()) return true;

                if (_head.Count > HeadMax)
                {
                    Send(400);
                    Close(true);
                    return false;
                }
            }
        }

        // Looks for the end of the request head (CR LF CR LF) in the buffered bytes.
        // When found: parses the request line and header fields into _context.Request, creates the
        // request reader and, when Content-Length is positive, hands it the bytes already received
        // after the head; returns true. Returns false when more data is needed.
        bool ReadHead()
        {
            var count = _head.Count;
            if (count < 5) return false;

            var array = _head.Origin;
            var end = count - 3;

            // Resume near where the previous call stopped instead of rescanning from the start;
            // stepping back 3 bytes keeps a terminator that was split across two reads findable.
            var start = _headread > 3 ? _headread - 3 : _headread;
            for (var i = start; i < end; i++)
            {
                if (array[i] != 13 || array[i + 1] != 10 || array[i + 2] != 13 || array[i + 3] != 10) continue;

                ParseHeadText(Encoding.ASCII.GetString(array, 0, i));

                // Bytes already received after the head belong to the request body.
                var remains = count - i - 4;
                var reader = GetRequestStream();
                if (remains > 0 && _context.Request.ContentLength > 0)
                {
                    var body = new byte[remains];
                    Buffer.BlockCopy(array, i + 4, body, 0, remains);
                    reader.RemainsBytes = body;
                }

                return true;
            }

            _headread = end;
            return false;
        }

        // Parses the request line ("GET /index.html HTTP/1.1") and the header fields of a request head.
        // A request line without exactly three space-separated parts is answered with 400, the
        // connection is closed and InvalidDataException is thrown to abort the caller.
        void ParseHeadText(string text)
        {
            var first = true;
            foreach (var line in text.Split('\r', '\n'))
            {
                if (line.IsEmpty()) continue;

                if (first)
                {
                    first = false;

                    // Consecutive spaces are tolerated; empty segments are dropped before indexing.
                    var segs = line.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
                    if (segs.Length != 3)
                    {
                        Send(400);
                        Close(true);
                        throw new InvalidDataException("请求的报文无效。");
                    }

                    _context.Request.Method = segs[0];
                    _context.Request.Path = segs[1];
                    _context.Request.Version = segs[2];
                    continue;
                }

                // Header field: "Name: value". Lines without a name before the colon are ignored.
                var colon = line.IndexOf(':');
                if (colon < 1) continue;

                var name = TextUtility.Trim(line.Substring(0, colon));
                var value = TextUtility.Trim(line.Substring(colon + 1));
                if (!string.IsNullOrEmpty(name) && !string.IsNullOrEmpty(value))
                {
                    _context.Request.Headers.Add(name, value);
                }
            }
        }

        // Completion of an asynchronous head read: buffers the data, parses the head once it is
        // complete and processes the request, or schedules the next read.
        // Exceptions are caught here because an exception escaping an AsyncCallback is unhandled.
        void ProcessHead(IAsyncResult ar)
        {
            if (_socket == null) return;

            try
            {
                int nread;
                try
                {
                    nread = _stream.EndRead(ar);
                }
                catch
                {
                    // The read failed: answer 400 and close the socket.
                    Send(400);
                    Close(true);
                    return;
                }

                // The peer closed the stream before the head was complete.
                if (nread < 1)
                {
                    Timeout();
                    CloseSocket();
                    return;
                }

                _head.Add(_buffer, 0, nread);

                if (ReadHead())
                {
                    ProcessHead();
                    return;
                }

                // The head is larger than allowed.
                if (_head.Count > HeadMax)
                {
                    Send(400);
                    Close(true);
                    return;
                }

                _stream.BeginRead(_buffer, 0, BufferSize, HeadReadCallback, this);
            }
            catch
            {
                Timeout();
                CloseSocket();
            }
        }

        // Validates the parsed head, fills derived request properties (HTTP version, keep-alive,
        // compression flags, URL), invokes the server handler and finishes the response.
        void ProcessHead()
        {
            // The timer is not running while the request is processed.
            Timeout();

            // Nothing was received.
            if (_head.Count < 1)
            {
                Close(true);
                return;
            }

            // Expect: 100-continue is not handled here; see SendContinue.

            // HTTP method.
            var method = ParseMethod(_context.Request.Method);
            if (method == HttpMethod.NULL)
            {
                Send(405);
                Close(true);
                return;
            }

            // HTTP protocol version.
            var version = TextUtility.Upper(_context.Request.Version);
            switch (version)
            {
                case "HTTP/1":
                case "HTTP/1.0":
                    _context.Request.Http11 = false;
                    break;
                case "HTTP/1.1":
                    _context.Request.Http11 = true;
                    break;
                default:
                    Send(505);
                    Close(true);
                    return;
            }

            // Keep-alive is only honoured for HTTP/1.1 requests that send "Connection: keep-alive".
            _context.Request.KeepAlive = _context.Request.Http11 && TextUtility.Lower(_context.Request.Headers.GetValue("Connection")) == "keep-alive";

            // Compression.
            if (_server.Compression) ParseAcceptEncoding(_context.Request.Headers.GetValue("Accept-Encoding"));

            // URL
            Uri url;
            if (!TryBuildUrl(out url))
            {
                Send(400);
                Close(true);
                return;
            }
            _context.Request.Url = url;

            // Handler
            var handler = _server.Handler;
            if (handler == null)
            {
                Send(501);
                Close(true);
                return;
            }

            // Invoke
            handler.Invoke(_context);
            SendHead();
            SendBody();
        }

        // Sets the request's compression flags from an Accept-Encoding header value such as
        // "gzip, deflate;q=0.5". Items are trimmed, parameters are ignored, and items with q=0 are skipped.
        void ParseAcceptEncoding(string headerValue)
        {
            if (string.IsNullOrEmpty(headerValue)) return;

            foreach (var item in headerValue.Split(','))
            {
                var parts = item.Split(';');
                if (IsZeroQuality(parts)) continue;

                switch (parts[0].Trim().ToLowerInvariant())
                {
                    case "gzip":
                        _context.Request.Gzip = true;
                        break;
                    // NOTE(review): the registered content-coding token for Brotli is "br"; only the
                    // literal "brotli" is recognised here, as before. Confirm which Content-Encoding
                    // the response side emits before also accepting "br".
                    case "brotli":
                        _context.Request.Brotli = true;
                        break;
                }
            }
        }

        // True when the parameters of one Accept-Encoding item (parts[1..]) carry a zero weight,
        // e.g. "q=0" or "q=0.000", meaning the client refuses that coding.
        static bool IsZeroQuality(string[] parts)
        {
            for (var i = 1; i < parts.Length; i++)
            {
                var parameter = parts[i].Trim();
                if (parameter.Length < 3) continue;
                if (parameter[0] != 'q' && parameter[0] != 'Q') continue;
                if (parameter[1] != '=') continue;
                if (parameter[2] != '0') return false;

                for (var j = 3; j < parameter.Length; j++)
                {
                    var c = parameter[j];
                    if (c != '0' && c != '.') return false;
                }
                return true;
            }
            return false;
        }

        // Builds the absolute request URL. An absolute-form request target is used as is; otherwise the
        // URL is scheme + Host header (or the local end point when Host is missing) + request path.
        // The scheme is https when the connection uses TLS. Returns false when no valid URI results.
        bool TryBuildUrl(out Uri url)
        {
            var path = _context.Request.Path;

            // absolute-form, e.g. "GET http://example.com/index.html HTTP/1.1".
            if (path != null && (path.StartsWith("http://", StringComparison.OrdinalIgnoreCase) || path.StartsWith("https://", StringComparison.OrdinalIgnoreCase)))
            {
                return Uri.TryCreate(path, UriKind.Absolute, out url);
            }

            var host = _context.Request.Headers.GetValue("Host");
            if (string.IsNullOrEmpty(host))
            {
                var local = LocalEndPoint;
                if (local != null)
                {
                    // IPv6 literals must be bracketed inside a URL; the port keeps the URL accurate
                    // when the server does not listen on the default port.
                    var address = local.Address.ToString();
                    if (local.Address.AddressFamily == AddressFamily.InterNetworkV6) address = "[" + address + "]";
                    host = $"{address}:{local.Port}";
                }
            }

            var scheme = _sslstream != null ? "https" : "http";
            return Uri.TryCreate($"{scheme}://{host}{path}", UriKind.Absolute, out url);
        }

        // Sends the interim "100 Continue" response for the request's HTTP version.
        internal void SendContinue()
        {
            Timeout(ResponseHeadTimeout);
            var http11 = _context.Request.Http11;
            var text = http11 ? "HTTP/1.1 100 Continue\r\n\r\n" : "HTTP/1.0 100 Continue\r\n\r\n";
            var bytes = Encoding.ASCII.GetBytes(text);

            // Write through _stream rather than the raw socket so the bytes are encrypted when TLS is used.
            _stream.Write(bytes, 0, bytes.Length);
            _stream.Flush();

            // The deadline only covers this write; the timer is otherwise off while the handler runs
            // (ProcessHead disarms it), so a slow upload that follows is not cut off after 15 seconds.
            Timeout();
        }

        // Returns the request body reader, creating it on first use. The reader gets the
        // Content-Length value, or -1 when the header is missing or not a plain integer.
        internal MiniReader GetRequestStream()
        {
            if (_instream != null) return _instream;

            var headers = _context.Request.Headers;
            var length = -1L;
            var value = headers.GetValue("Content-Length");
            if (!string.IsNullOrEmpty(value))
            {
                var num = value.Int64();
                if (num.ToString() == value) length = num;
            }
            _instream = new MiniReader(this, length);

            Timeout(RequestBodyTimeout);
            return _instream;
        }

        #endregion

        #region response

        /// <summary>Whether the response body is chunked: enabled on the server and the request is HTTP/1.1.</summary>
        public bool Chunked { get => _chunked && _context.Request.Http11; }

        // Builds and writes the response status line and header fields, once per request.
        // Closes the request reader first, and records whether the connection must be closed afterwards.
        internal void SendHead()
        {
            if (_headsent) return;
            _headsent = true;
            _instream?.Close();

            var lines = new ArrayBuilder<string>(32);
            var response = _context.Response;
            var headers = _context.Response.Headers;

            // redirect
            var location = response.Location;
            if (string.IsNullOrEmpty(location)) location = null;

            // version
            var isHttp11 = _context.Request.Http11;
            var version = isHttp11 ? "1.1" : "1.0";

            // status: default 200, forced to 302 when redirecting.
            var status = response.Status;
            if (status == 0) status = 200;
            if (location != null) status = 302;
            var statusDesc = NetworkUtility.HttpStatusDescription(status);
            if (statusDesc.IsEmpty()) statusDesc = "OK";
            lines.Add($"HTTP/{version} {status} {statusDesc}");

            // keep-alive
            // NOTE(review): forceClose starts out true and the conditions below can only set it to true
            // again, so every response closes the connection and keep-alive reuse is effectively
            // disabled. Enabling it would first require reviewing Close()/BeginRead re-entrancy and
            // how unread request bodies are drained.
            var forceClose = true;
            var keepAlive = false;
            if (isHttp11)
            {
                if (status == 400 || status == 408 || status == 411 || status == 413 || status == 414 || status == 500 || status == 503) forceClose = true;
                if (_reuses > 128) forceClose = true;

                // request: keep-alive
                keepAlive = _context.Request.KeepAlive;
                if (!keepAlive) forceClose = true;

                // response: keep-alive
                if (!_context.Response.KeepAlive) forceClose = true;

                if (keepAlive)
                {
                    if (forceClose)
                    {
                        lines.Add("Connection: close");
                    }
                    else
                    {
                        lines.Add("Connection: keep-alive");
                        lines.Add($"Keep-Alive: timeout=15, max={128 - _reuses}");
                    }
                }
                if (Chunked) lines.Add("Transfer-Encoding: chunked");
            }

            // redirect
            if (location != null) lines.Add($"Location: {location}");

            // content-type: the ContentType property wins over a custom header.
            var contentType = response.ContentType;
            if (string.IsNullOrEmpty(contentType)) contentType = headers.GetValue("content-type", true);
            if (!string.IsNullOrEmpty(contentType)) lines.Add("Content-Type:" + contentType);

            // content-length
            var contentLength = response.ContentLength;
            if (contentLength >= 0) lines.Add("Content-length:" + contentLength);

            // Custom headers, minus those already written above.
            foreach (var header in headers)
            {
                var key = header.Key;
                if (key.IsEmpty()) continue;

                var lower = key.Lower();
                if (lower == "content-length") continue;
                if (lower == "content-type") continue;
                if (keepAlive)
                {
                    if (lower == "connection") continue;
                    if (lower == "keep-alive") continue;
                }

                // Location only ever comes from response.Location; repeating it would duplicate the header.
                if (lower == "location") continue;

                var value = header.Value;
                if (value.IsEmpty()) continue;

                lines.Add(key + ":" + value);
            }

            // Record the close decision before writing, so it is in place even if the write fails.
            _forceClose = forceClose;

            var text = string.Join("\r\n", lines.Export());
            text += "\r\n\r\n";
            var bytes = TextUtility.Bytes(text);
            if (_socket == null) return;

            // Write through _stream rather than the raw socket so the head is encrypted when TLS is used.
            // The deadline covers only this write; the body phase arms its own (GetResponseStream).
            Timeout(ResponseHeadTimeout);
            _stream.Write(bytes, 0, bytes.Length);
            _stream.Flush();
            Timeout();
        }

        // Finishes the response body once: closes the response writer, then the connection,
        // honouring the close decision made by SendHead.
        internal void SendBody()
        {
            if (_bodysent) return;
            _bodysent = true;
            _outstream?.Close();
            Close(_forceClose);
        }

        // Returns the response body writer, creating it on first use after the head has been sent.
        internal MiniWriter GetResponseStream()
        {
            if (_outstream != null) return _outstream;

            SendHead();
            _outstream = new MiniWriter(this);

            Timeout(ResponseBodyTimeout);
            return _outstream;
        }

        /// <summary>Sets the response status code and closes the response.</summary>
        /// <remarks>NOTE(review): presumably MiniResponse.Close finishes the connection via SendHead/SendBody; confirm.</remarks>
        internal void Send(int status)
        {
            var response = _context.Response;
            response.Status = status;
            response.Close();
        }

        #endregion

        #region static

        // Maps a request method name (case-insensitive, exact match) to HttpMethod; HttpMethod.NULL when unknown.
        static HttpMethod ParseMethod(string method)
        {
            if (string.IsNullOrEmpty(method)) return HttpMethod.NULL;

            switch (TextUtility.Upper(method))
            {
                case "OPTIONS": return HttpMethod.OPTIONS;
                case "POST": return HttpMethod.POST;
                case "GET": return HttpMethod.GET;
                case "CONNECT": return HttpMethod.CONNECT;
                case "DELETE": return HttpMethod.DELETE;
                case "HEAD": return HttpMethod.HEAD;
                case "PATCH": return HttpMethod.PATCH;
                case "PUT": return HttpMethod.PUT;
                case "TRACE": return HttpMethod.TRACE;
                default: return HttpMethod.NULL;
            }
        }

        // Wraps the stream in an SslStream and performs the server-side TLS handshake synchronously.
        // The SslStream owns the inner stream, so disposing it on a failed handshake also closes that stream.
        static SslStream CreateSslStream(Stream stream, X509Certificate certificate, SslProtocols protocols)
        {
            var sslStream = new SslStream(stream, false, ApproveAll);
            try
            {
                // NOTE(review): clientCertificateRequired is true, so clients are asked for a certificate;
                // ApproveAll accepts any certificate, or none. Confirm that the request is intended.
                sslStream.AuthenticateAsServer(certificate, true, protocols, false);
                return sslStream;
            }
            catch
            {
                sslStream.Dispose();
                throw;
            }
        }

        /// <summary>Certificate validation that ignores every error.</summary>
        static bool ApproveAll(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors errors)
        {
            return true;
        }

        /// <summary>Local certificate selection: returns the first non-null local certificate, or null.</summary>
        static X509Certificate ApproveFirst(object sender, string targetHost, X509CertificateCollection localCertificates, X509Certificate remoteCertificate, string[] acceptableIssuers)
        {
            if (localCertificates != null)
            {
                for (var i = 0; i < localCertificates.Count; i++)
                {
                    var certificate = localCertificates[i];
                    if (certificate != null) return certificate;
                }
            }
            return null;
        }

        // Reads Content-Length from the headers; -1 when missing or not a plain integer.
        // Not referenced elsewhere in this class.
        static long GetContentLength(StringPairs headers)
        {
            if (headers != null)
            {
                var value = headers.GetValue("Content-Length", true);
                if (!string.IsNullOrEmpty(value))
                {
                    var num = value.Int64();
                    if (num.ToString() == value) return num;
                }
            }
            return -1L;
        }

        #endregion

    }

}
